{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fix Poor Predictions from Comprehend Custom Text Classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "import sagemaker\n",
    "import pandas as pd\n",
    "\n",
    "# Notebook-wide configuration: `bucket` is used later to read the human-loop output\n",
    "# from S3 and `region` is passed to the boto3 clients created in the next cell.\n",
    "sess = sagemaker.Session()\n",
    "bucket = sess.default_bucket()\n",
    "role = sagemaker.get_execution_role()\n",
    "region = boto3.Session().region_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import io\n",
    "import json\n",
    "import uuid\n",
    "import time\n",
    "import boto3\n",
    "import botocore\n",
    "\n",
    "# AWS SDK for Python (boto3) clients. Every client is pinned to the same region so the\n",
    "# flow definition, Comprehend endpoint and S3 output are all resolved consistently.\n",
    "# NOTE: `sagemaker` rebinds the name of the SageMaker SDK module imported in the setup\n",
    "# cell; from this cell on it refers to the low-level boto3 SageMaker client.\n",
    "sagemaker = boto3.client(\"sagemaker\", region_name=region)\n",
    "comprehend = boto3.client(\"comprehend\", region_name=region)\n",
    "a2i = boto3.client(\"sagemaker-a2i-runtime\", region_name=region)\n",
    "s3 = boto3.client(\"s3\", region_name=region)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Retrieve the `augmented_ai_flow_definition_arn` Created Previously"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store -r augmented_ai_flow_definition_arn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(augmented_ai_flow_definition_arn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# _Wait for the Comprehend Job to Complete from the Previous Section_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Retrieve the `comprehend_endpoint_arn` Deployed Previously"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store -r comprehend_endpoint_arn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# The variable only exists if `%store -r` above found a stored value for it.\n",
    "if \"comprehend_endpoint_arn\" not in globals():\n",
    "    print(\"*** PLEASE WAIT FOR THE Comprehend JOB TO FINISH IN THE PREVIOUS SECTION BEFORE CONTINUING ***\")\n",
    "    print(\"*** YOU WILL NEED TO RESTART THIS NOTEBOOK ONCE THE JOB FINISHES ***\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(comprehend_endpoint_arn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Check the Confidence Score for Each Comprehend Prediction\n",
    "If the confidence score is below the threshold, start a human loop.  You can integrate this type of logic into your application using the SDK.  In this case, we're using the Python SDK."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Use Comprehend to Predict Some Sample Reviews"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Short reviews sent to the Comprehend endpoint in the cells below.\n",
    "sample_reviews = [\n",
    "    \"I enjoy this product\",\n",
    "    \"I am unhappy with this product\",\n",
    "    \"It is okay\",\n",
    "    \"sometimes it works\",\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Start a Human Loop When Comprehend Does Not Produce a Confident Prediction\n",
    "The human loop will engage our workforce for human review if the confidence of the Comprehend prediction is less than the confidence threshold.\n",
    "\n",
    "![](img/augmented-ai-comprehend-predictions.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Names of the human loops started by this cell; later cells poll them for completion.\n",
    "human_loops_started = []\n",
    "\n",
    "# Predictions scoring below this value are sent to a human for review.\n",
    "CONFIDENCE_SCORE_THRESHOLD = 0.90\n",
    "\n",
    "for sample_review in sample_reviews:\n",
    "    # Call the Comprehend Custom model that we trained earlier\n",
    "    response = comprehend.classify_document(Text=sample_review, EndpointArn=comprehend_endpoint_arn)\n",
    "\n",
    "    print(f'Processing sample_review: \"{sample_review}\"')\n",
    "\n",
    "    predicted_classes = response.get(\"Classes\") or []\n",
    "    if not predicted_classes:\n",
    "        # Without a prediction there is no initial value to show the reviewer.\n",
    "        print(\"Comprehend returned no classes for this review. No human loop created. \\n\")\n",
    "        continue\n",
    "\n",
    "    # Pick the highest-scoring class explicitly instead of relying on response ordering.\n",
    "    top_class = max(predicted_classes, key=lambda predicted_class: predicted_class[\"Score\"])\n",
    "    star_rating = top_class[\"Name\"]\n",
    "    confidence_score = top_class[\"Score\"]\n",
    "\n",
    "    # Our condition for when we want to engage a human for review\n",
    "    if confidence_score < CONFIDENCE_SCORE_THRESHOLD:\n",
    "\n",
    "        humanLoopName = str(uuid.uuid4())\n",
    "        inputContent = {\"initialValue\": star_rating, \"taskObject\": sample_review}\n",
    "        start_loop_response = a2i.start_human_loop(\n",
    "            HumanLoopName=humanLoopName,\n",
    "            FlowDefinitionArn=augmented_ai_flow_definition_arn,\n",
    "            HumanLoopInput={\"InputContent\": json.dumps(inputContent)},\n",
    "        )\n",
    "\n",
    "        human_loops_started.append(humanLoopName)\n",
    "\n",
    "        print(\n",
    "            f\"Confidence score of {confidence_score} for star rating of {star_rating} is less than the threshold of {CONFIDENCE_SCORE_THRESHOLD}\"\n",
    "        )\n",
    "        print(f\"*** ==> Starting human loop with name: {humanLoopName}  \\n\")\n",
    "    else:\n",
    "        print(\n",
    "            f\"Confidence score of {confidence_score} for star rating of {star_rating} is above threshold of {CONFIDENCE_SCORE_THRESHOLD}\"\n",
    "        )\n",
    "        print(\"No human loop created. \\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Check Status of Human Loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One-off status snapshot of every human loop started above (no waiting).\n",
    "completed_human_loops = []\n",
    "for human_loop_name in human_loops_started:\n",
    "    resp = a2i.describe_human_loop(HumanLoopName=human_loop_name)\n",
    "    print(f\"HumanLoop Name: {human_loop_name}\")\n",
    "    print(f'HumanLoop Status: {resp[\"HumanLoopStatus\"]}')\n",
    "    # `HumanLoopOutput` is an optional response field, so read it defensively\n",
    "    # instead of raising KeyError for loops that have not produced output yet.\n",
    "    print(f'HumanLoop Output Destination: {resp.get(\"HumanLoopOutput\")}')\n",
    "    print(\"\")\n",
    "\n",
    "    if resp[\"HumanLoopStatus\"] == \"Completed\":\n",
    "        completed_human_loops.append(resp)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Wait For Workers to Complete Their Human Loop Tasks\n",
    "\n",
    "Navigate to the link below and log in with the email and password that you used when you set up the Private Workforce."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store -r augmented_ai_workteam_arn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(augmented_ai_workteam_arn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use a dedicated boto3 client so this cell works no matter whether the name `sagemaker`\n",
    "# currently refers to the SageMaker SDK module or to a boto3 client.\n",
    "sagemaker_client = boto3.client(\"sagemaker\", region_name=region)\n",
    "\n",
    "# The workteam name is the last path segment of its ARN.\n",
    "workteam_name = augmented_ai_workteam_arn.rsplit(\"/\", 1)[-1]\n",
    "print(workteam_name)\n",
    "print(\"Navigate to the private worker portal and complete the human loop.\")\n",
    "print(\"Make sure you have invited yourself to the workteam and received the signup email.\")\n",
    "print(\"Note:  Check your spam filter if you have not received the email.\")\n",
    "print(\"\")\n",
    "workteam = sagemaker_client.describe_workteam(WorkteamName=workteam_name)[\"Workteam\"]\n",
    "print(\"https://\" + workteam[\"SubDomain\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Start Labeling\n",
    "\n",
    "<img src=\"img/augmented-comprehend-custom-start-working.png\" width=\"80%\" align=\"left\">"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Select Label\n",
    "\n",
    "<img src=\"img/augmented-comprehend-custom-select-label.png\" width=\"80%\" align=\"left\">"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Loop is Completed\n",
    "\n",
    "<img src=\"img/augmented-comprehend-custom-finished-task.png\" width=\"80%\" align=\"left\">"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Verify the Human Loops are Completed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "# Statuses after which a human loop will not change any more.\n",
    "FINAL_HUMAN_LOOP_STATUSES = (\"Completed\", \"Failed\", \"Stopped\")\n",
    "\n",
    "completed_human_loops = []\n",
    "for human_loop_name in human_loops_started:\n",
    "    resp = a2i.describe_human_loop(HumanLoopName=human_loop_name)\n",
    "    print(f\"HumanLoop Name: {human_loop_name}\")\n",
    "    print(f'HumanLoop Status: {resp[\"HumanLoopStatus\"]}')\n",
    "    # `HumanLoopOutput` is an optional response field, so read it defensively.\n",
    "    print(f'HumanLoop Output Destination: {resp.get(\"HumanLoopOutput\")}')\n",
    "    print(\"\")\n",
    "    # Poll only while the loop can still finish. Waiting for \"Completed\" alone would\n",
    "    # spin forever on a loop that ended as \"Failed\" or \"Stopped\".\n",
    "    while resp[\"HumanLoopStatus\"] not in FINAL_HUMAN_LOOP_STATUSES:\n",
    "        print(\"Waiting for HumanLoop to complete.\")\n",
    "        time.sleep(10)\n",
    "        resp = a2i.describe_human_loop(HumanLoopName=human_loop_name)\n",
    "    if resp[\"HumanLoopStatus\"] == \"Completed\":\n",
    "        completed_human_loops.append(resp)\n",
    "        print(\"Completed!\")\n",
    "        print(\"\")\n",
    "    else:\n",
    "        # Report loops that ended without a result; they are left out of completed_human_loops.\n",
    "        final_status = resp[\"HumanLoopStatus\"]\n",
    "        failure_reason = resp.get(\"FailureReason\", \"no failure reason reported\")\n",
    "        print(f\"HumanLoop ended with status {final_status}: {failure_reason}\")\n",
    "        print(\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# View Human Labels  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once the work is complete, Amazon A2I stores the results in the specified S3 bucket and sends a CloudWatch Event.  Let's check the S3 contents."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import re\n",
    "import pprint\n",
    "\n",
    "pp = pprint.PrettyPrinter(indent=4)\n",
    "\n",
    "S3_URI_PREFIX = \"s3://\"\n",
    "\n",
    "# One record per completed human loop: the original input plus the reviewer's answer.\n",
    "fixed_items = []\n",
    "\n",
    "for resp in completed_human_loops:\n",
    "    output_s3_uri = resp[\"HumanLoopOutput\"][\"OutputS3Uri\"]\n",
    "    if not output_s3_uri.startswith(S3_URI_PREFIX):\n",
    "        raise ValueError(f\"Unexpected human loop output location: {output_s3_uri}\")\n",
    "\n",
    "    # Take bucket and key from the URI itself. Splitting on the default bucket name with a\n",
    "    # regular expression misreads dots in the name as wildcards and breaks when the flow\n",
    "    # definition writes its output to a different bucket.\n",
    "    output_bucket, _separator, output_bucket_key = output_s3_uri[len(S3_URI_PREFIX) :].partition(\"/\")\n",
    "\n",
    "    response = s3.get_object(Bucket=output_bucket, Key=output_bucket_key)\n",
    "    content = response[\"Body\"].read().decode(\"utf-8\")\n",
    "    json_output = json.loads(content)\n",
    "    print(json_output)\n",
    "\n",
    "    input_content = json_output[\"inputContent\"]\n",
    "    # Only the first reviewer's answer is kept for each loop.\n",
    "    human_answer = json_output[\"humanAnswers\"][0][\"answerContent\"]\n",
    "    fixed_item = {\"input_content\": input_content, \"human_answer\": human_answer}\n",
    "    fixed_items.append(fixed_item)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prepare the Data for Re-training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_fixed_items = pd.DataFrame(fixed_items)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "df_fixed_items.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Once finished, delete the Comprehend Custom Model Endpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# comprehend.delete_endpoint(EndpointArn=comprehend_endpoint_arn)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%javascript\n",
    "Jupyter.notebook.save_checkpoint()\n",
    "Jupyter.notebook.session.delete();"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
